{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make sure to run this notebook inside the xlnet repo folder after cloning it:\n",
    "\n",
    "```bash\n",
    "git clone https://github.com/zihangdai/xlnet.git\n",
    "cd xlnet\n",
    "jupyter notebook\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip -O xlnet.zip\n",
    "# !unzip xlnet.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Default to GPU 2, but keep a CUDA_VISIBLE_DEVICES value that is already set in the environment.\n",
    "if 'CUDA_VISIBLE_DEVICES' not in os.environ:\n",
    "    os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('xlnet_cased_L-12_H-768_A-12/spiece.model')\n",
    "\n",
    "def tokenize_fn(text):\n",
    "    \"\"\"Run `preprocess_text` on `text` (case kept) and encode the result with `sp_model`.\"\"\"\n",
    "    cleaned = preprocess_text(text, lower=False)\n",
    "    return encode_ids(sp_model, cleaned)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Segment ids for sentence A, sentence B, the <cls> token, <sep> and padding, in that order.\n",
    "SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD = range(5)\n",
    "\n",
    "# Special symbols, numbered by their position in this list.\n",
    "special_symbols = {\n",
    "    symbol: index\n",
    "    for index, symbol in enumerate([\n",
    "        \"<unk>\", \"<s>\", \"</s>\", \"<cls>\", \"<sep>\",\n",
    "        \"<pad>\", \"<mask>\", \"<eod>\", \"<eop>\",\n",
    "    ])\n",
    "}\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "UNK_ID, CLS_ID, SEP_ID, MASK_ID, EOD_ID = (\n",
    "    special_symbols[symbol]\n",
    "    for symbol in (\"<unk>\", \"<cls>\", \"<sep>\", \"<mask>\", \"<eod>\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['train_X', 'train_Y', 'test_X', 'test_Y'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# Decode explicitly as UTF-8 so the platform's default locale encoding cannot change how the file is read.\n",
    "with open('../text.json', encoding='utf-8') as fopen:\n",
    "    data = json.load(fopen)\n",
    "\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "MAX_SEQ_LENGTH = 120\n",
    "\n",
    "def _truncate_seq_pair(tokens_a, tokens_b, max_length):\n",
    "    \"\"\"Trim the two lists in place until their combined length is at most `max_length`.\n",
    "\n",
    "    One token at a time is popped from the end of whichever list is currently\n",
    "    longer; when both have the same length the token is taken from `tokens_b`.\n",
    "    \"\"\"\n",
    "    while len(tokens_a) + len(tokens_b) > max_length:\n",
    "        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b\n",
    "        longer.pop()\n",
    "                \n",
    "def get_data(left, right):\n",
    "    \"\"\"Encode sentence pairs as `A <sep> B <sep> <cls>` id sequences.\n",
    "\n",
    "    `left[k]` and `right[k]` form one pair. Returns three parallel lists with one\n",
    "    entry per pair: the token ids, an all-zero input mask of the same length,\n",
    "    and the segment ids.\n",
    "    \"\"\"\n",
    "    input_ids, input_mask, all_seg_ids = [], [], []\n",
    "    for row in tqdm(range(len(left))):\n",
    "        ids_a = tokenize_fn(left[row])\n",
    "        ids_b = tokenize_fn(right[row])\n",
    "\n",
    "        # Three positions are reserved for the two <sep> markers and the trailing <cls>.\n",
    "        _truncate_seq_pair(ids_a, ids_b, MAX_SEQ_LENGTH - 3)\n",
    "\n",
    "        pair_ids = [*ids_a, SEP_ID, *ids_b, SEP_ID, CLS_ID]\n",
    "        # Each <sep> carries the segment id of the sentence it closes.\n",
    "        pair_segments = (\n",
    "            [SEG_ID_A] * (len(ids_a) + 1)\n",
    "            + [SEG_ID_B] * (len(ids_b) + 1)\n",
    "            + [SEG_ID_CLS]\n",
    "        )\n",
    "        pair_mask = [0] * len(pair_ids)\n",
    "\n",
    "        assert len(pair_ids) == len(pair_mask)\n",
    "        assert len(pair_ids) == len(pair_segments)\n",
    "        input_ids.append(pair_ids)\n",
    "        input_mask.append(pair_mask)\n",
    "        all_seg_ids.append(pair_segments)\n",
    "    return input_ids, input_mask, all_seg_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 261802/261802 [01:14<00:00, 3504.24it/s]\n"
     ]
    }
   ],
   "source": [
    "# Every training row holds both sentences, separated by ' <> '.\n",
    "left, right = [], []\n",
    "for pair_text in data['train_X']:\n",
    "    first, second = pair_text.split(' <> ')\n",
    "    left.append(first)\n",
    "    right.append(second)\n",
    "\n",
    "train_ids, train_masks, segment_train = get_data(left, right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13395/13395 [00:03<00:00, 4423.89it/s]\n"
     ]
    }
   ],
   "source": [
    "# Every test row holds both sentences, separated by ' <> '.\n",
    "left, right = [], []\n",
    "for pair_text in data['test_X']:\n",
    "    first, second = pair_text.split(' <> ')\n",
    "    left.append(first)\n",
    "    right.append(second)\n",
    "\n",
    "test_ids, test_masks, segment_test = get_data(left, right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/model_utils.py:295: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/xlnet.py:63: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import xlnet\n",
    "import model_utils\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "# Options forwarded to xlnet.RunConfig.\n",
    "# NOTE(review): is_training=True applies to the only graph this notebook builds,\n",
    "# and that same graph is used for the test pass later on.\n",
    "kwargs = {\n",
    "    'is_training': True,\n",
    "    'use_tpu': False,\n",
    "    'use_bfloat16': False,\n",
    "    'dropout': 0.1,\n",
    "    'dropatt': 0.1,\n",
    "    'init': 'normal',\n",
    "    'init_range': 0.1,\n",
    "    'init_std': 0.05,\n",
    "    'clamp_len': -1,\n",
    "}\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(\n",
    "    json_path='xlnet_cased_L-12_H-768_A-12/xlnet_config.json'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "245439 24543\n"
     ]
    }
   ],
   "source": [
    "epoch = 15\n",
    "batch_size = 16\n",
    "warmup_proportion = 0.1\n",
    "\n",
    "# Step counts for the learning-rate schedule: total optimiser steps over `epoch`\n",
    "# passes of the training set, and the share of them spent warming up.\n",
    "num_train_steps = int(len(train_ids) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "print(num_train_steps, num_warmup_steps)\n",
    "\n",
    "# Keyword arguments for the Parameter holder; entries it does not name are ignored there.\n",
    "training_parameters = {\n",
    "    'decay_method': 'poly',\n",
    "    'train_steps': num_train_steps,\n",
    "    'learning_rate': 2e-5,\n",
    "    'warmup_steps': num_warmup_steps,\n",
    "    'min_lr_ratio': 0.0,\n",
    "    'weight_decay': 0.00,\n",
    "    'adam_epsilon': 1e-8,\n",
    "    'num_core_per_host': 1,\n",
    "    'lr_layer_decay_rate': 1,\n",
    "    'use_tpu': False,\n",
    "    'use_bfloat16': False,\n",
    "    'dropout': 0.0,\n",
    "    'dropatt': 0.0,\n",
    "    'init': 'normal',\n",
    "    'init_range': 0.1,\n",
    "    'init_std': 0.02,\n",
    "    'clip': 1.0,\n",
    "    'clamp_len': -1,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter:\n",
    "    \"\"\"Plain attribute container for the optimiser settings passed to it.\n",
    "\n",
    "    Only the named arguments are stored as attributes; any additional keyword\n",
    "    arguments are accepted and discarded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, decay_method, warmup_steps, weight_decay, adam_epsilon,\n",
    "                 num_core_per_host, lr_layer_decay_rate, use_tpu, learning_rate, train_steps,\n",
    "                 min_lr_ratio, clip, **kwargs):\n",
    "        stored = (\n",
    "            ('decay_method', decay_method),\n",
    "            ('warmup_steps', warmup_steps),\n",
    "            ('weight_decay', weight_decay),\n",
    "            ('adam_epsilon', adam_epsilon),\n",
    "            ('num_core_per_host', num_core_per_host),\n",
    "            ('lr_layer_decay_rate', lr_layer_decay_rate),\n",
    "            ('use_tpu', use_tpu),\n",
    "            ('learning_rate', learning_rate),\n",
    "            ('train_steps', train_steps),\n",
    "            ('min_lr_ratio', min_lr_ratio),\n",
    "            ('clip', clip),\n",
    "        )\n",
    "        for field, value in stored:\n",
    "            setattr(self, field, value)\n",
    "        \n",
    "# The settings dict is replaced by a Parameter object under the same name. Convert only\n",
    "# while it is still a dict, so running this cell a second time does not raise.\n",
    "if isinstance(training_parameters, dict):\n",
    "    training_parameters = Parameter(**training_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "768"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xlnet_config.d_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"XLNet sentence-pair classifier trained with a margin-based softmax loss.\n",
    "\n",
    "    Builds, in order: the input placeholders, the XLNet encoder, a dense projection\n",
    "    whose output is L2-normalised, a bias-free class layer with unit-norm weights,\n",
    "    the loss, the train op from `model_utils.get_train_op`, and a batch accuracy.\n",
    "    The loss looks like the Circle Loss formulation (gamma / margin / O_p / O_n /\n",
    "    Delta_p / Delta_n) -- TODO confirm against the paper.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        # NOTE(review): `learning_rate` is accepted but never read in this constructor;\n",
    "        # `self.learning_rate` is the tensor returned by `model_utils.get_train_op`.\n",
    "\n",
    "        # All inputs are fed batch-major; axis 0 of X is the batch.\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.batch_size = tf.shape(self.X)[0]\n",
    "\n",
    "        # The encoder receives the three inputs with their two axes swapped.\n",
    "        encoder = xlnet.XLNetModel(\n",
    "            xlnet_config=xlnet_config,\n",
    "            run_config=xlnet_parameters,\n",
    "            input_ids=tf.transpose(self.X, [1, 0]),\n",
    "            seg_ids=tf.transpose(self.segment_ids, [1, 0]),\n",
    "            input_mask=tf.transpose(self.input_masks, [1, 0]))\n",
    "\n",
    "        pooled = encoder.get_pooled_out(\"last\", True)\n",
    "        projected = tf.layers.dense(pooled, xlnet_config.d_model)\n",
    "        self.out = tf.nn.l2_normalize(projected, 1)\n",
    "        self.logits = tf.layers.dense(\n",
    "            self.out,\n",
    "            dimension_output,\n",
    "            use_bias=False,\n",
    "            kernel_constraint=tf.keras.constraints.unit_norm(),\n",
    "        )\n",
    "\n",
    "        # Loss hyper-parameters; the four thresholds are derived from the margin.\n",
    "        self.gamma, self.margin = 64, 0.25\n",
    "        self.O_p = 1 + self.margin\n",
    "        self.O_n = -self.margin\n",
    "        self.Delta_p = 1 - self.margin\n",
    "        self.Delta_n = self.margin\n",
    "\n",
    "        # (row, label) coordinates of the target logit of every example.\n",
    "        self.batch_idxs = tf.expand_dims(\n",
    "            tf.range(0, self.batch_size, dtype=tf.int32), 1)  # shape [batch,1]\n",
    "        target_coords = tf.concat([self.batch_idxs, tf.cast(self.Y, tf.int32)], 1)\n",
    "        target_logit = tf.expand_dims(tf.gather_nd(self.logits, target_coords), 1)\n",
    "\n",
    "        # True everywhere except at the target coordinates.\n",
    "        is_target = tf.scatter_nd(\n",
    "            target_coords,\n",
    "            tf.ones(tf.shape(target_coords)[0], tf.bool),\n",
    "            tf.shape(self.logits),\n",
    "        )\n",
    "        non_target = tf.logical_not(is_target)\n",
    "        other_logits = tf.reshape(\n",
    "            tf.boolean_mask(self.logits, non_target), (self.batch_size, -1)\n",
    "        )\n",
    "\n",
    "        # Weights are computed from gradient-stopped logits, so they act as constants in backprop.\n",
    "        target_weight = tf.nn.relu(self.O_p - tf.stop_gradient(target_logit))\n",
    "        other_weight = tf.nn.relu(tf.stop_gradient(other_logits) - self.O_n)\n",
    "\n",
    "        weighted_target = target_weight * (target_logit - self.Delta_p)\n",
    "        weighted_other = other_weight * (other_logits - self.Delta_n)\n",
    "        scaled = tf.concat([weighted_other, weighted_target], 1)\n",
    "        scaled = scaled * self.gamma\n",
    "        # log-sum-exp over all classes of one example\n",
    "        log_partition = tf.math.reduce_logsumexp(scaled, 1, keepdims=True)\n",
    "        # log-sum-exp minus the scaled target term, then averaged over the batch\n",
    "        per_example = -weighted_target * self.gamma + log_partition\n",
    "        self.cost = tf.reduce_mean(per_example[:,0])\n",
    "\n",
    "        self.optimizer, self.learning_rate, _ = model_utils.get_train_op(\n",
    "            training_parameters, self.cost\n",
    "        )\n",
    "\n",
    "        predicted = tf.argmax(self.logits, 1, output_type = tf.int32)\n",
    "        hits = tf.equal(predicted, self.Y[:,0])\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/xlnet.py:220: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/xlnet.py:220: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/modeling.py:453: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/modeling.py:460: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/modeling.py:535: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/modeling.py:67: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py:1475: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/model_utils.py:96: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/model_utils.py:108: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/text-similarity/xlnet/model_utils.py:131: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "dimension_output, learning_rate = 2, 2e-5\n",
    "\n",
    "# Clear the default graph before building, then open the session the later cells use.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(dimension_output, learning_rate = learning_rate)\n",
    "\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Match the given variables against the variables stored in a checkpoint.\n",
    "\n",
    "    Args:\n",
    "        tvars: variables of the current graph; each needs a `.name` attribute.\n",
    "        init_checkpoint: checkpoint path passed to `tf.train.list_variables`.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(assignment_map, initialized_variable_names)`. `assignment_map` is\n",
    "        an OrderedDict from checkpoint variable name to the current variable of the\n",
    "        same name, limited to names present on both sides (an intersection, in\n",
    "        checkpoint order). `initialized_variable_names` maps every matched name,\n",
    "        both bare and with a `:0` suffix, to 1.\n",
    "    \"\"\"\n",
    "    # Strip a trailing `:<digits>` from each variable name before matching.\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        m = re.match('^(.*):\\\\d+$', var.name)\n",
    "        name = m.group(1) if m is not None else var.name\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    initialized_variable_names = {}\n",
    "    for entry in tf.train.list_variables(init_checkpoint):\n",
    "        name = entry[0]\n",
    "        if name not in name_to_variable:\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = 'xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt'\n",
    "tvars = tf.trainable_variables()\n",
    "\n",
    "# Pair every trainable variable with the same-named entry of the checkpoint.\n",
    "(assignment_map, initialized_variable_names) = get_assignment_map_from_checkpoint(\n",
    "    tvars, checkpoint\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt\n"
     ]
    }
   ],
   "source": [
    "# Restore only the variables listed in assignment_map from the checkpoint.\n",
    "saver = tf.train.Saver(var_list=assignment_map)\n",
    "saver.restore(sess, save_path=checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A label's class index is its position in this list.\n",
    "labels = ['contradiction', 'entailment']\n",
    "\n",
    "train_Y = [labels.index(name) for name in data['train_Y']]\n",
    "test_Y = [labels.index(name) for name in data['test_Y']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "\n",
    "# The model pools with summary type \"last\" and get_data puts <cls> at the end of every\n",
    "# sequence, so rows are padded on the LEFT: with right padding the final position of the\n",
    "# shorter rows would be a pad slot instead of <cls>.\n",
    "# Padded positions get segment id SEG_ID_PAD and mask value 1 (real tokens have mask 0).\n",
    "batch_x = pad_sequences(train_ids[:5], padding='pre')\n",
    "batch_y = np.expand_dims(train_Y[:5], 1)\n",
    "batch_segments = pad_sequences(segment_train[:5], padding='pre', value=SEG_ID_PAD)\n",
    "batch_masks = pad_sequences(train_masks[:5], padding='pre', value=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.2, 55.720844]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One forward pass on the five-row batch; the cell displays [accuracy, cost].\n",
    "sess.run(\n",
    "    [model.accuracy, model.cost],\n",
    "    feed_dict = {\n",
    "        model.X: batch_x,\n",
    "        model.Y: batch_y,\n",
    "        model.segment_ids: batch_segments,\n",
    "        model.input_masks: batch_masks,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 16363/16363 [52:21<00:00,  5.21it/s, accuracy=1, cost=0.0359]    \n",
      "test minibatch loop: 100%|██████████| 838/838 [00:54<00:00, 15.41it/s, accuracy=1, cost=0.00785] \n",
      "train minibatch loop:   0%|          | 1/16363 [00:00<47:11,  5.78it/s, accuracy=1, cost=0.482]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.914827\n",
      "time taken: 3195.6359593868256\n",
      "epoch: 0, training loss: 10.497539, training acc: 0.848343, valid loss: 5.863732, valid acc: 0.914827\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 16363/16363 [52:02<00:00,  5.24it/s, accuracy=1, cost=0.00451]   \n",
      "test minibatch loop: 100%|██████████| 838/838 [00:53<00:00, 15.59it/s, accuracy=1, cost=0.00506] \n",
      "train minibatch loop:   0%|          | 1/16363 [00:00<47:39,  5.72it/s, accuracy=0.938, cost=4.09]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.914827, current acc: 0.927431\n",
      "time taken: 3176.4944434165955\n",
      "epoch: 1, training loss: 5.274645, training acc: 0.924659, valid loss: 5.358865, valid acc: 0.927431\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 16363/16363 [52:19<00:00,  5.21it/s, accuracy=1, cost=0.00229]    \n",
      "test minibatch loop: 100%|██████████| 838/838 [00:52<00:00, 15.87it/s, accuracy=1, cost=0.00256] \n",
      "train minibatch loop:   0%|          | 1/16363 [00:00<47:11,  5.78it/s, accuracy=0.938, cost=4.94]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, pass acc: 0.927431, current acc: 0.936083\n",
      "time taken: 3192.366697072983\n",
      "epoch: 2, training loss: 3.858834, training acc: 0.948344, valid loss: 5.224887, valid acc: 0.936083\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 16363/16363 [50:28<00:00,  5.40it/s, accuracy=1, cost=0.00073]   \n",
      "test minibatch loop: 100%|██████████| 838/838 [00:52<00:00, 15.98it/s, accuracy=1, cost=0.000478]\n",
      "train minibatch loop:   0%|          | 1/16363 [00:00<47:44,  5.71it/s, accuracy=1, cost=0.000715]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, pass acc: 0.936083, current acc: 0.937425\n",
      "time taken: 3081.159793138504\n",
      "epoch: 3, training loss: 2.776629, training acc: 0.964443, valid loss: 5.418724, valid acc: 0.937425\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 16363/16363 [49:28<00:00,  5.51it/s, accuracy=1, cost=0.000696]\n",
      "test minibatch loop: 100%|██████████| 838/838 [00:52<00:00, 16.08it/s, accuracy=1, cost=0.00111] \n",
      "train minibatch loop:   0%|          | 1/16363 [00:00<47:49,  5.70it/s, accuracy=1, cost=0.0011]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, pass acc: 0.937425, current acc: 0.937724\n",
      "time taken: 3020.8846673965454\n",
      "epoch: 4, training loss: 1.990225, training acc: 0.975364, valid loss: 5.434131, valid acc: 0.937724\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 16363/16363 [52:20<00:00,  5.21it/s, accuracy=1, cost=0.00395]    \n",
      "test minibatch loop: 100%|██████████| 838/838 [00:52<00:00, 15.88it/s, accuracy=1, cost=0.00252] \n",
      "train minibatch loop:   0%|          | 1/16363 [00:00<47:00,  5.80it/s, accuracy=1, cost=0.00254]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, pass acc: 0.937724, current acc: 0.940334\n",
      "time taken: 3192.947704553604\n",
      "epoch: 5, training loss: 1.449264, training acc: 0.982793, valid loss: 5.935048, valid acc: 0.940334\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 16363/16363 [50:09<00:00,  5.44it/s, accuracy=1, cost=0.00803] \n",
      "test minibatch loop: 100%|██████████| 838/838 [00:53<00:00, 15.68it/s, accuracy=1, cost=0.00222] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 3063.1185958385468\n",
      "epoch: 6, training loss: 1.140489, training acc: 0.987418, valid loss: 6.756607, valid acc: 0.939887\n",
      "\n",
      "break epoch:7\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 1, 0, 0, 0\n",
    "\n",
    "\n",
    "def make_feed_dict(ids, labels, segments, masks, start, stop):\n",
    "    \"\"\"Build the feed dict for rows [start, stop) of one data split.\n",
    "\n",
    "    Rows are padded on the left: the model pools with summary type \"last\" and\n",
    "    get_data puts <cls> at the end of every sequence, so right padding would leave\n",
    "    a pad slot in the final position of the shorter rows. Padded positions get\n",
    "    segment id SEG_ID_PAD and mask value 1.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        model.X: pad_sequences(ids[start:stop], padding='pre'),\n",
    "        model.Y: np.expand_dims(labels[start:stop], 1),\n",
    "        model.segment_ids: pad_sequences(\n",
    "            segments[start:stop], padding='pre', value=SEG_ID_PAD\n",
    "        ),\n",
    "        model.input_masks: pad_sequences(masks[start:stop], padding='pre', value=1),\n",
    "    }\n",
    "\n",
    "\n",
    "def run_epoch(ids, labels, segments, masks, desc, train):\n",
    "    \"\"\"Run one pass over a split; return (mean batch loss, mean batch accuracy).\n",
    "\n",
    "    When `train` is true the optimizer op is fetched too, so weights are updated.\n",
    "    NOTE(review): the graph is built with is_training=True, so dropout presumably\n",
    "    stays active during the test pass as well -- confirm whether that is intended.\n",
    "    \"\"\"\n",
    "    fetches = [model.accuracy, model.cost]\n",
    "    if train:\n",
    "        fetches.append(model.optimizer)\n",
    "\n",
    "    losses, accs = [], []\n",
    "    pbar = tqdm(range(0, len(ids), batch_size), desc = desc)\n",
    "    for start in pbar:\n",
    "        stop = min(start + batch_size, len(ids))\n",
    "        results = sess.run(\n",
    "            fetches,\n",
    "            feed_dict = make_feed_dict(ids, labels, segments, masks, start, stop),\n",
    "        )\n",
    "        acc, cost = results[0], results[1]\n",
    "        losses.append(cost)\n",
    "        accs.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    return np.mean(losses), np.mean(accs)\n",
    "\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    # Stop once EARLY_STOPPING consecutive epochs failed to beat the best test accuracy.\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_loss, train_acc = run_epoch(\n",
    "        train_ids, train_Y, segment_train, train_masks,\n",
    "        desc = 'train minibatch loop', train = True,\n",
    "    )\n",
    "    test_loss, test_acc = run_epoch(\n",
    "        test_ids, test_Y, segment_test, test_masks,\n",
    "        desc = 'test minibatch loop', train = False,\n",
    "    )\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
